import numpy as np
import torch
import math, json, ast
import torchvision.transforms as T
from PIL import Image
import cv2
import numpy as np
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer, AutoConfig
from utils.result_preprocess import Intern_VL_RES_PREPROCESS
# Per-channel (R, G, B) normalization statistics used by
# InternVL_Agent.build_transform(); these are the standard ImageNet values.
IMAGENET_MEAN = (
    0.485,  # R
    0.456,  # G
    0.406,  # B
)
IMAGENET_STD = (
    0.229,  # R
    0.224,  # G
    0.225,  # B
)

class InternVL_Agent:
    """Query an InternVL chat model about a perturbed version of an input image.

    ``get_action`` loads ``obs['images'][0]``, hides or zooms into the
    ground-truth action target according to ``args.probing_method``, tiles the
    result with ``dynamic_preprocess`` and passes it to ``self.model.chat``.
    """

    # AndroidControl widget classes whose clickable nodes are mask candidates.
    _ANDROID_CLICKABLE_CLASSES = (
        'android.widget.ImageButton',
        'android.widget.TextView',
        'android.widget.ImageView',
    )

    def __init__(self, device, accelerator, cache_dir='~/.cache', dropout=0.5, policy_lm=None):
        """Create the agent and its tokenizer; the model itself is loaded later.

        Args:
            device: Stored on ``self.device``; not otherwise used in this class.
            accelerator: Stored on ``self.accelerator``; not otherwise used here.
            cache_dir: Unused; kept for interface compatibility.
            dropout: Unused; kept for interface compatibility.
            policy_lm: Model name/path passed to the Hugging Face ``from_pretrained`` loaders.
        """
        self.model = None  # set by _load_model(); get_action() loads it on first use
        self.policy_lm = policy_lm
        self.tokenizer = AutoTokenizer.from_pretrained(self.policy_lm, trust_remote_code=True, use_fast=False)
        self.tokenizer.truncation_side = 'left'
        # Reuse EOS as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.device = device
        self.accelerator = accelerator
        self.res_pre_process = self._res_pre_process()
        # Greedy decoding; transformers ignores temperature/top_p when do_sample=False.
        self.generation_config = dict(
            num_beams=1,
            max_new_tokens=512,
            do_sample=False,
            temperature=0.0,
            top_p=0,
        )

    def _res_pre_process(self):
        """Return the response/label parser used to extract actions and coordinates."""
        return Intern_VL_RES_PREPROCESS()

    @staticmethod
    def _distribute_layers(num_layers, world_size):
        """Return the GPU index assigned to each of ``num_layers`` LLM layers, in order.

        GPU 0 also hosts the vision tower, so it is counted as half a GPU and
        gets about half as many layers as each other GPU.
        """
        world_size = max(int(world_size), 1)  # no visible GPU: put everything on device 0
        per_gpu = math.ceil(num_layers / (world_size - 0.5))
        counts = [per_gpu] * world_size
        counts[0] = math.ceil(per_gpu * 0.5)
        devices = []
        for gpu, count in enumerate(counts):
            devices.extend([gpu] * count)
        # Rounding up can over-allocate; never emit entries for layers that don't exist.
        return devices[:num_layers]

    def split_model(self):
        """Build a ``device_map`` that spreads the LLM layers over all visible GPUs.

        The vision model, projector (``mlp1``), embeddings, final norm, rotary
        embedding, output head and the last decoder layer are pinned to GPU 0.
        """
        config = AutoConfig.from_pretrained(self.policy_lm, trust_remote_code=True)
        num_layers = config.llm_config.num_hidden_layers
        layer_devices = self._distribute_layers(num_layers, torch.cuda.device_count())
        device_map = {
            f'language_model.model.layers.{idx}': gpu
            for idx, gpu in enumerate(layer_devices)
        }
        for module_name in (
            'vision_model',
            'mlp1',
            'language_model.model.tok_embeddings',
            'language_model.model.embed_tokens',
            'language_model.output',
            'language_model.model.norm',
            'language_model.model.rotary_emb',
            'language_model.lm_head',
        ):
            device_map[module_name] = 0
        if num_layers > 0:
            # The last decoder layer also goes to GPU 0, alongside the final norm and output head.
            device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
        return device_map

    def _load_model(self):
        """Load the model in bfloat16 with ``split_model()``'s device map, store it on ``self.model`` and return it."""
        device_map = self.split_model()
        self.model = AutoModel.from_pretrained(
            self.policy_lm,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True).eval()
        return self.model

    def build_transform(self, input_size):
        """Return a transform: force RGB, bicubic-resize to ``input_size`` square, to tensor, ImageNet-normalize."""
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
        return transform

    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        """Pick the (cols, rows) grid from ``target_ratios`` closest to ``aspect_ratio``.

        On an exact tie, a later candidate wins if the image area exceeds half
        of that candidate grid's pixel area. Returns ``(1, 1)`` if
        ``target_ratios`` is empty.
        """
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio

    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
        """Tile ``image`` into ``image_size`` squares on the best-matching grid.

        Candidate grids are all (cols, rows) with ``min_num <= cols * rows <= max_num``.
        The image is resized to fill the chosen grid exactly and cut row-major
        into tiles. When ``use_thumbnail`` is set and more than one tile was
        produced, a whole-image ``image_size`` thumbnail is appended.
        """
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # All (cols, rows) grids whose tile count lies in [min_num, max_num], fewest tiles first.
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        cols, rows = self.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        resized_img = image.resize((image_size * cols, image_size * rows))
        processed_images = []
        for i in range(cols * rows):
            col, row = i % cols, i // cols
            box = (col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size)
            processed_images.append(resized_img.crop(box))
        if use_thumbnail and len(processed_images) != 1:
            processed_images.append(image.resize((image_size, image_size)))
        return processed_images

    def load_image(self, image_file, obs, args, input_size=448, max_num=12):
        """Load ``image_file``, apply the probing perturbation, and tile it.

        Returns:
            A tensor stacking one transformed ``input_size`` x ``input_size``
            tile per crop (plus a thumbnail when more than one crop is made).
        """
        # Context manager closes the file; convert() returns an independent, loaded copy.
        with Image.open(image_file) as raw:
            image = raw.convert('RGB')
        if args.probing_method == 'zoom':
            # The zoom-adjusted label/bbox is not used here.
            image, _zoom_info = self.zoom_in(image, obs, args)
        else:
            # 'visual_mask' blacks out the target; other methods inpaint it
            # (the mode is dispatched inside visual_mask()).
            image = self.visual_mask(image, obs, args)
        transform = self.build_transform(input_size=input_size)
        tiles = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        return torch.stack([transform(tile) for tile in tiles])

    def get_action(self, obs, args):
        """Ask ``obs['question']`` about the perturbed ``obs['images'][0]`` and return the response."""
        if self.model is None:
            # Load on first use instead of failing with AttributeError on None.
            self._load_model()
        pixel_values = self.load_image(obs['images'][0], max_num=6, obs=obs, args=args)
        pixel_values = pixel_values.to(torch.bfloat16).cuda()
        response, _history = self.model.chat(
            tokenizer=self.tokenizer,
            pixel_values=pixel_values,
            question=obs['question'],
            generation_config=self.generation_config,
            history=None,
            return_history=True,
        )
        return response

    def _android_control_bboxes(self, tree_path, image_width, image_height):
        """Read clickable-widget boxes from an AndroidControl JSON-lines accessibility file.

        Returns ``[x, y, w, h]`` boxes built from ``bbox_pixels`` for clickable
        nodes of ``_ANDROID_CLICKABLE_CLASSES`` that lie fully inside the image.
        Malformed lines/nodes are skipped.
        """
        bboxes = []
        with open(tree_path, "r", encoding="utf-8") as file:
            for line in file:
                try:
                    node = json.loads(line)
                    bbox = node.get("bbox_pixels", None)
                    if not (bbox
                            and node.get("class_name", None) in self._ANDROID_CLICKABLE_CLASSES
                            and node.get("is_clickable")):
                        continue
                    x_min, y_min, x_max, y_max = bbox["x_min"], bbox["y_min"], bbox["x_max"], bbox["y_max"]
                    inside = 0 <= x_min < x_max <= image_width and 0 <= y_min < y_max <= image_height
                except (ValueError, KeyError, TypeError, AttributeError):
                    # Best-effort parsing: bad JSON, missing keys or odd value types.
                    continue
                if inside:
                    bboxes.append([x_min, y_min, x_max - x_min, y_max - y_min])
        return bboxes

    def _aitz_bboxes(self, tree_path, image_path):
        """Return ``[x, y, w, h]`` UI boxes for ``image_path`` from an AITZ JSON annotation list.

        A record matches when its ``image_path`` is a substring of ``image_path``;
        if several match, the last one is used, and if none match, no boxes are returned.
        """
        with open(tree_path, "r", encoding="utf-8") as file:
            records = json.load(file)
        positions = []
        for record in records:
            # NOTE(review): substring match can hit several records (e.g. "1.png" in ".../11.png").
            if record['image_path'] in image_path:
                positions = ast.literal_eval(record['ui_positions'])
        # ui_positions entries are (y, x, h, w); reorder to (x, y, w, h).
        return [[x, y, w, h] for (y, x, h, w) in positions]

    @staticmethod
    def _clamp_box(x_min, y_min, x_max, y_max, width, height):
        """Truncate box corners to ints and clamp them into a ``width`` x ``height`` image.

        Without clamping, a negative slice start is taken from the end by NumPy,
        which turns a box touching the border into an empty mask.
        """
        def clamp(value, upper):
            return min(max(int(value), 0), upper)
        return clamp(x_min, width), clamp(y_min, height), clamp(x_max, width), clamp(y_max, height)

    def _black_out(self, image, bbox_list, point, args):
        """Fill each ``[x, y, w, h]`` box black on ``image`` in place; if none, fill the click square."""
        from PIL import ImageDraw
        draw = ImageDraw.Draw(image)
        if bbox_list:
            for x, y, w, h in bbox_list:
                draw.rectangle([x, y, x + w, y + h], fill="black")
        else:
            r = args.mask_object_ratio
            draw.rectangle([point[0] - r, point[1] - r, point[0] + r, point[1] + r], fill="black")
        return image

    def _inpaint(self, image, bbox_list, point, args):
        """Telea-inpaint every ``[x, y, w, h]`` box, then the click square; returns a new PIL image."""
        # cv2 works in BGR: convert in, and back out so PIL receives RGB again.
        canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        height, width = canvas.shape[:2]

        def inpaint_box(img, x_min, y_min, x_max, y_max):
            x0, y0, x1, y1 = self._clamp_box(x_min, y_min, x_max, y_max, width, height)
            mask = np.zeros((height, width), dtype=np.uint8)
            mask[y0:y1, x0:x1] = 255
            return cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)

        for x, y, w, h in bbox_list:
            # Boxes are [x, y, w, h], so the far corner is (x + w, y + h).
            canvas = inpaint_box(canvas, x, y, x + w, y + h)
        r = args.mask_object_ratio
        px, py = point
        canvas = inpaint_box(canvas, px - r, py - r, px + r, py + r)
        return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    def visual_mask(self, image_input, obs, args):
        """Hide the ground-truth target in ``image_input``.

        Candidate boxes come from AndroidControl/AITZ accessibility data, or
        from ``obs['bbox']`` (``[x_min, y_min, x_max, y_max]``) otherwise. The
        boxes containing the ground-truth click are hidden. With
        ``args.probing_method == 'visual_mask'`` they are filled black in place.
        If no box contains the click, a square of half-width
        ``args.mask_object_ratio`` around the click is filled instead. Any other
        method inpaints those boxes and then always the click square.
        """
        image_width, image_height = image_input.size
        dataset_name = obs.get('dataset_name')
        if dataset_name == 'AndroidControl':
            bbox_data = self._android_control_bboxes(obs['accessibility_trees'], image_width, image_height)
            label = obs.get('label')
        elif dataset_name == 'AITZ':
            bbox_data = self._aitz_bboxes(obs['accessibility_trees'], obs.get('images')[0])
            label = obs['label']
        else:
            raw_bbox = obs.get('bbox')
            if raw_bbox is None:
                bbox_data = []
            else:
                bbox_data = [[raw_bbox[0], raw_bbox[1], raw_bbox[2] - raw_bbox[0], raw_bbox[3] - raw_bbox[1]]]
            label = obs['label']
        gt = self.res_pre_process.extract_coordinates(self.res_pre_process.extract_action(label))

        _, bbox_list, point = self.remove_containing_bboxes(
            bbox_list=bbox_data, gt=gt, image_size=[image_width, image_height])
        if args.probing_method == 'visual_mask':
            return self._black_out(image_input, bbox_list, point, args)
        return self._inpaint(image_input, bbox_list, point, args)

    def remove_containing_bboxes(self, bbox_list, gt, image_size):
        """Split ``bbox_list`` by whether each box contains the ground-truth click.

        Args:
            bbox_list: ``[x, y, w, h]`` boxes in pixels; ``None`` is treated as empty.
            gt: Click point ``(x, y)`` on a 0-1000 normalized scale.
            image_size: ``[width, height]`` in pixels.

        Returns:
            ``(outside, containing, (click_x, click_y))``: boxes not containing
            the click, boxes containing it (edges inclusive), and the click in pixels.
        """
        click_x = gt[0] / 1000 * image_size[0]
        click_y = gt[1] / 1000 * image_size[1]
        if bbox_list is None:
            bbox_list = []
        out_bbox_list, in_bbox_list = [], []
        for bbox in bbox_list:
            x, y, w, h = bbox
            if x <= click_x <= x + w and y <= click_y <= y + h:
                in_bbox_list.append(bbox)
            else:
                out_bbox_list.append(bbox)
        return out_bbox_list, in_bbox_list, (click_x, click_y)

    def zoom_in(self, pil_image, obs, args):
        """Zoom into the image quadrant that contains the ground-truth click.

        Only action type 1 (handled as a click) is zoomed. The quadrant is
        cropped and resized back to the original size, and the click and
        ``obs['bbox']`` are mapped into the zoomed frame.

        Returns:
            ``(zoomed_image, {"label": ..., "bbox": ...})`` for click actions,
            where ``label`` is None unless ``args.model_name`` contains
            OS_Atlas or OS_Genesis; otherwise ``(pil_image, None)``.

        Raises:
            ValueError: if ``obs`` has no usable ``'label'``.
        """
        try:
            content = obs['label']
        except (IndexError, KeyError, TypeError) as err:
            raise ValueError("Invalid message format in obs") from err

        ground_truth = self.res_pre_process.extract_action(content)
        task = self.res_pre_process.get_action_type(ground_truth)
        bbox = obs.get("bbox")  # [x_min, y_min, x_max, y_max], rescaled from 0-1000 below

        w, h = pil_image.size

        if task != 1:
            return pil_image, None

        click_x, click_y = self.res_pre_process.extract_coordinates(ground_truth)
        # NOTE(review): the click is used as pixel coordinates here (no /1000
        # rescaling), while remove_containing_bboxes() treats the same
        # extract_coordinates() output as 0-1000 normalized. Likewise obs['bbox']
        # is rescaled from 0-1000 here but used as pixels in visual_mask().
        # Confirm the convention per dataset.
        mid_x, mid_y = w // 2, h // 2
        left, right = (0, mid_x) if click_x < mid_x else (mid_x, w)
        top, bottom = (0, mid_y) if click_y < mid_y else (mid_y, h)
        region = (left, top, right, bottom)

        zoomed_image = pil_image.crop(region).resize((w, h), Image.LANCZOS)

        def transform_coord(x, y):
            """Map a pixel point into the zoomed frame (truncated to int)."""
            scale_x = w / (right - left)
            scale_y = h / (bottom - top)
            return int((x - left) * scale_x), int((y - top) * scale_y)

        new_click_x, new_click_y = transform_coord(click_x, click_y)
        label = None  # stays None for models other than OS_Atlas / OS_Genesis
        if "OS_Atlas" in args.model_name:
            norm_click_x = new_click_x / w * 1000
            norm_click_y = new_click_y / h * 1000
            label = "action:\n" + f"CLICK <point>[[{norm_click_x}, {norm_click_y}]]</point>"
        elif "OS_Genesis" in args.model_name:
            label = f'Low-level thought: action: {{"action_type": "click", "x": {new_click_x}, "y": {new_click_y}}}'

        new_bbox = None
        if bbox is not None:
            x_min, y_min = transform_coord(bbox[0] / 1000 * w, bbox[1] / 1000 * h)
            x_max, y_max = transform_coord(bbox[2] / 1000 * w, bbox[3] / 1000 * h)
            new_bbox = [x_min / w * 1000, y_min / h * 1000, x_max / w * 1000, y_max / h * 1000]

        return zoomed_image, {"label": label, "bbox": new_bbox}
        